#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import tempfile
from timeit import default_timer as timer

from tabulate import tabulate

# hackish way to reuse existing start code
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../tests/macro_benchmark/")
from clients import *
from databases import *


def _take_snapshot_size_mb(snapshots_dir):
    """Return the size (in MB) of the only file in *snapshots_dir*, then delete it.

    The file is removed so that the next benchmark iteration starts with an
    empty snapshots directory.

    Raises:
        RuntimeError: if the directory does not contain exactly one entry.
    """
    # List the directory once so the check and the lookup see the same state.
    snapshot_files = os.listdir(snapshots_dir)
    if len(snapshot_files) != 1:
        # An explicit raise (instead of `assert`) keeps the check under `-O`.
        raise RuntimeError(
            "Expected exactly one snapshot in '{}', found {}".format(snapshots_dir, len(snapshot_files))
        )
    snapshot_file = os.path.join(snapshots_dir, snapshot_files[0])
    size_mb = round(os.path.getsize(snapshot_file) / 1024.0 / 1024.0, 2)
    os.remove(snapshot_file)
    return size_mb


def main():
    """Benchmark snapshot recovery for several graph sizes and print a table.

    For every combination of node count, edges per node and properties per
    node, a Memgraph instance started with ``--storage-snapshot-on-exit`` is
    filled with data and stopped. A second instance started with
    ``--data-recovery-on-startup`` on the same data directory is then started
    and stopped while the elapsed wall-clock time is measured.

    One row per combination (nodes, edges, properties, snapshot size in MB,
    elapsed seconds) is printed with ``tabulate``.

    Raises:
        RuntimeError: if the snapshots directory does not contain exactly one
            file after a create/recover cycle.
    """
    # The context manager guarantees the temporary data directory is removed
    # when the benchmark finishes or fails.
    with tempfile.TemporaryDirectory() as durability_dir:
        DURABILITY_DIR_ARG = ["--data-directory", durability_dir]
        MAKE_SNAPSHOT_ARGS = ["--storage-snapshot-on-exit"] + DURABILITY_DIR_ARG
        RECOVER_SNAPSHOT_ARGS = ["--data-recovery-on-startup"] + DURABILITY_DIR_ARG
        # NOTE(review): the second constructor argument (1) is defined by the
        # macro_benchmark helpers; presumably a worker/thread count - confirm.
        snapshot_memgraph = Memgraph(MAKE_SNAPSHOT_ARGS, 1)
        recover_memgraph = Memgraph(RECOVER_SNAPSHOT_ARGS, 1)
        client = QueryClient(None, 1)

        # Loop-invariant: both instances share the same data directory.
        snapshots_dir = os.path.join(durability_dir, "snapshots")

        results = []
        for node_cnt in [10**6, 5 * 10**6, 10**7]:
            for edge_per_node in [0, 1, 3]:
                for prop_per_node in [0, 1]:
                    snapshot_memgraph.start()
                    # Cypher property map body, e.g. "p0: 0"; empty when no
                    # properties are requested.
                    properties = ",".join("p{}: 0".format(x) for x in range(prop_per_node))
                    client(
                        ["UNWIND RANGE(1, {}) AS _ CREATE ({{ {} }})".format(node_cnt, properties)],
                        snapshot_memgraph,
                    )
                    # One self-loop per node for every value in the range.
                    client(
                        ["UNWIND RANGE(1, {}) AS _ MATCH (n) CREATE (n)-[:l]->(n)".format(edge_per_node)],
                        snapshot_memgraph,
                    )
                    snapshot_memgraph.stop()

                    # This waits for the snapshot to be recovered and then exits
                    start = timer()
                    recover_memgraph.start()
                    recover_memgraph.stop()
                    stop = timer()
                    diff = stop - start

                    snap_size = _take_snapshot_size_mb(snapshots_dir)

                    edge_cnt = edge_per_node * node_cnt
                    results.append((node_cnt, edge_cnt, prop_per_node, snap_size, diff))

    print(
        tabulate(
            tabular_data=results, headers=["Nodes", "Edges", "Properties", "Snapshot size (MB)", "Elapsed time (s)"]
        )
    )


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
